import csv
import os
from typing import List
import matplotlib.pyplot as plt


def load_trace(path: str) -> List[dict]:
	"""Read a CSV trace file into a list of row dictionaries.

	The first line of the file is treated as the header; every following
	record becomes a dict keyed by column name.

	Args:
		path: Path to a UTF-8 encoded CSV file.

	Returns:
		One dict per data row, in file order. Values are the raw strings read
		from the file; no type conversion is performed.
	"""
	# newline="" is required by the csv module: it lets the reader handle line
	# endings itself, so quoted fields with embedded newlines survive intact.
	with open(path, "r", encoding="utf-8", newline="") as fh:
		return list(csv.DictReader(fh))


def plot_reward(path: str, out_png: str) -> None:
	"""Plot the per-step ``reward`` column of a CSV trace and save it as a PNG.

	Args:
		path: CSV trace file read via ``load_trace``; it must contain a
			``reward`` column whose values parse as floats.
		out_png: Destination image path. Its parent directory is created if
			missing; a bare filename is written to the current directory.

	Raises:
		KeyError: if a row has no ``reward`` column.
		ValueError: if a ``reward`` value cannot be parsed as a float.
	"""
	rows = load_trace(path)
	rewards = [float(r["reward"]) for r in rows]

	out_dir = os.path.dirname(out_png)
	# os.makedirs("") raises FileNotFoundError, so only create a directory
	# when out_png actually contains one.
	if out_dir:
		os.makedirs(out_dir, exist_ok=True)

	fig = plt.figure(figsize=(8, 3))
	try:
		plt.plot(rewards, linewidth=1.0)
		plt.title(f"Reward per step — {os.path.basename(path)}")
		plt.xlabel("step")
		plt.ylabel("reward")
		plt.tight_layout()
		plt.savefig(out_png, dpi=150)
	finally:
		# Always release the figure, even if saving fails, so repeated calls
		# do not accumulate open figures in pyplot's global registry.
		plt.close(fig)


def plot_position(path: str, out_png: str) -> None:
	"""Plot the per-step ``norm_pos`` column of a CSV trace and save it as a PNG.

	Args:
		path: CSV trace file read via ``load_trace``; it must contain a
			``norm_pos`` column whose values parse as floats.
		out_png: Destination image path. Its parent directory is created if
			missing; a bare filename is written to the current directory.

	Raises:
		KeyError: if a row has no ``norm_pos`` column.
		ValueError: if a ``norm_pos`` value cannot be parsed as a float.
	"""
	rows = load_trace(path)
	pos = [float(r["norm_pos"]) for r in rows]

	out_dir = os.path.dirname(out_png)
	# os.makedirs("") raises FileNotFoundError, so only create a directory
	# when out_png actually contains one.
	if out_dir:
		os.makedirs(out_dir, exist_ok=True)

	fig = plt.figure(figsize=(8, 3))
	try:
		plt.plot(pos, linewidth=1.0)
		plt.title(f"Normalized position — {os.path.basename(path)}")
		plt.xlabel("step")
		plt.ylabel("norm_pos")
		plt.tight_layout()
		plt.savefig(out_png, dpi=150)
	finally:
		# Always release the figure, even if saving fails, so repeated calls
		# do not accumulate open figures in pyplot's global registry.
		plt.close(fig)


if __name__ == "__main__":
	# Example usage: for each trace log, render its reward and position plots.
	# Each entry is (trace CSV, reward PNG, position PNG).
	example_jobs = (
		("logs/flat_traces.csv", "plots/flat_reward.png", "plots/flat_pos.png"),
		("logs/hier_traces.csv", "plots/hier_reward.png", "plots/hier_pos.png"),
	)
	for trace_csv, reward_png, pos_png in example_jobs:
		plot_reward(trace_csv, reward_png)
		plot_position(trace_csv, pos_png)


